# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
import warnings

import cv2
import numpy as np
import paddle
from PIL import Image
from tqdm.auto import trange

from ppdiffusers import (
    FlowMatchEulerDiscreteScheduler,
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusion3Pipeline,
    UniPCMultistepScheduler,
)
from ppdiffusers.utils import load_image



def strtobool(v):
    """Convert a command-line style truthy/falsy value into a ``bool``.

    Args:
        v: A ``bool`` (returned unchanged) or a string, which is matched
            case-insensitively against the accepted spellings.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0.

    Raises:
        ValueError: If the string is not one of the accepted spellings.
    """
    if isinstance(v, bool):
        return v

    normalized = v.lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False

    raise ValueError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )


def change_scheduler(self, scheduler_type="ddim"):
    """Build a scheduler of the requested type from the current scheduler's config.

    The current ``self.scheduler.config`` is stored on ``self`` as
    ``original_scheduler_config`` (overwritten on every call) and used as the
    base config for the new scheduler. The new scheduler is only returned; it
    is not assigned to ``self.scheduler`` here.

    Args:
        self: Object exposing ``scheduler.config``.
        scheduler_type: Case-insensitive scheduler name. One of ``flow``,
            ``pndm``, ``lms``, ``heun``, ``euler``, ``euler-ancestral``,
            ``dpm-multi``, ``dpm-single``, ``kdpm2-ancestral``, ``kdpm2``,
            ``unipc-multi``, ``ddim``, ``ddpm``, ``deis-multi``.

    Returns:
        The scheduler created by the selected class's ``from_config``.

    Raises:
        ValueError: If ``scheduler_type`` is not one of the names above.
    """
    self.original_scheduler_config = self.scheduler.config
    scheduler_type = scheduler_type.lower()
    config = self.original_scheduler_config

    # Builders are lambdas so that only the selected scheduler class is
    # looked up and instantiated.
    builders = {
        # `skip_prk_steps` is a PNDM option, so it is passed to PNDM only and
        # no longer to the flow-matching scheduler.
        "flow": lambda: FlowMatchEulerDiscreteScheduler.from_config(config),
        "pndm": lambda: PNDMScheduler.from_config(config, skip_prk_steps=True),
        "lms": lambda: LMSDiscreteScheduler.from_config(config),
        "heun": lambda: HeunDiscreteScheduler.from_config(config),
        "euler": lambda: EulerDiscreteScheduler.from_config(config),
        "euler-ancestral": lambda: EulerAncestralDiscreteScheduler.from_config(config),
        "dpm-multi": lambda: DPMSolverMultistepScheduler.from_config(config),
        "dpm-single": lambda: DPMSolverSinglestepScheduler.from_config(config),
        "kdpm2-ancestral": lambda: KDPM2AncestralDiscreteScheduler.from_config(config),
        "kdpm2": lambda: KDPM2DiscreteScheduler.from_config(config),
        "unipc-multi": lambda: UniPCMultistepScheduler.from_config(config),
        "ddim": lambda: DDIMScheduler.from_config(
            config,
            steps_offset=1,
            clip_sample=False,
            set_alpha_to_one=False,
        ),
        "ddpm": lambda: DDPMScheduler.from_config(config),
        "deis-multi": lambda: DEISMultistepScheduler.from_config(config),
    }

    build = builders.get(scheduler_type)
    if build is None:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
    return build()

def get_paddle_memory_info():
    """Return Paddle's CUDA memory counters, each divided by ``2**30``.

    Returns:
        A 4-tuple ``(memory_allocated, max_memory_allocated, memory_reserved,
        max_memory_reserved)``. Each value is the raw counter divided by
        ``2**30`` (GiB if the counters are in bytes, as Paddle documents).
    """
    cuda = paddle.device.cuda
    scale = 2**30

    allocated = cuda.memory_allocated() / scale
    peak_allocated = cuda.max_memory_allocated() / scale
    reserved = cuda.memory_reserved() / scale
    peak_reserved = cuda.max_memory_reserved() / scale

    return (allocated, peak_allocated, reserved, peak_reserved)
    
def parse_arguments():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``pretrained_model_name_or_path``, ``inference_steps``,
        ``benchmark_steps``, ``task_name``, ``parse_prompt_type``,
        ``use_fp16``, ``device_id``, ``scheduler``, ``height``, ``width``
        and ``strength``.
    """
    task_choices = [
        "text2img",
        "img2img",
        "inpaint_legacy",
        "all",
    ]
    prompt_type_choices = [
        "raw",
        "lpw",
    ]
    scheduler_choices = [
        "flow",
        "pndm",
        "lms",
        "euler",
        "euler-ancestral",
        "dpm-multi",
        "dpm-single",
        "unipc-multi",
        "ddim",
        "ddpm",
        "deis-multi",
        "heun",
        "kdpm2-ancestral",
        "kdpm2",
    ]

    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Model and benchmark length.
    add(
        "--pretrained_model_name_or_path",
        type=str,
        default="stabilityai/stable-diffusion-3-medium-diffusers",
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the bos).",
    )
    add("--inference_steps", type=int, default=50, help="The number of unet inference steps.")
    add("--benchmark_steps", type=int, default=10, help="The number of performance benchmark steps.")

    # Task selection and prompt handling.
    add(
        "--task_name",
        type=str,
        default="all",
        choices=task_choices,
        help="The task can be one of [text2img, img2img, inpaint_legacy, all]. ",
    )
    add(
        "--parse_prompt_type",
        type=str,
        default="raw",
        choices=prompt_type_choices,
        help="The parse_prompt_type can be one of [raw, lpw]. ",
    )

    # Runtime configuration.
    add("--use_fp16", type=strtobool, default=True, help="Whether to use FP16 mode")
    add("--device_id", type=int, default=0, help="The selected gpu id. -1 means use cpu")
    add(
        "--scheduler",
        type=str,
        default="euler-ancestral",
        choices=scheduler_choices,
        help="The scheduler type of stable diffusion.",
    )

    # Image geometry and denoising strength.
    add("--height", type=int, default=512, help="Height of input image")
    add("--width", type=int, default=512, help="Width of input image")
    add("--strength", type=float, default=1.0, help="Strength for img2img / inpaint")

    return cli.parse_args()


def main(args):
    """Benchmark Stable Diffusion 3 text-to-image inference and save one sample.

    Loads ``StableDiffusion3Pipeline`` in fp16 or fp32, swaps in the requested
    scheduler, and creates the output folder (``paddle_fp16`` or
    ``paddle_fp32``). When ``args.task_name`` is ``"text2img"`` or ``"all"``
    it runs one warmup call, then ``args.benchmark_steps`` timed calls,
    prints throughput, latency and GPU memory statistics, and saves the
    first image of the last run as ``<folder>/text2img.png``. No other task
    is handled in this function.

    Args:
        args: Parsed options. Reads ``use_fp16``,
            ``pretrained_model_name_or_path``, ``scheduler``, ``width``,
            ``height``, ``task_name``, ``benchmark_steps`` and
            ``inference_steps``.

    Raises:
        ValueError: If the text2img benchmark is selected and
            ``args.benchmark_steps`` is less than 1.
    """
    run_text2img = args.task_name in ("text2img", "all")
    # Fail before loading the model: with zero timed runs there is nothing to
    # average and no image to save.
    if run_text2img and args.benchmark_steps < 1:
        raise ValueError(f"benchmark_steps must be >= 1, got {args.benchmark_steps}.")

    seed = 1024
    paddle_dtype = paddle.float16 if args.use_fp16 else paddle.float32
    pipe = StableDiffusion3Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
        paddle_dtype=paddle_dtype,
    )
    pipe.scheduler = change_scheduler(pipe, args.scheduler)

    width = args.width
    height = args.height
    pipe.set_progress_bar_config(disable=False)

    folder = "paddle_fp16" if args.use_fp16 else "paddle_fp32"
    os.makedirs(folder, exist_ok=True)

    if run_text2img:
        prompt = "bird"
        time_costs = []
        memory_metrics = []

        # Warmup run, excluded from the timings.
        pipe(
            prompt,
            num_inference_steps=10,
            height=height,
            width=width,
        )
        print("==> Test text2img performance.")
        for _ in trange(args.benchmark_steps):
            # Re-seed before every run; seeding is kept out of the timed region.
            paddle.seed(seed)
            start = time.perf_counter()
            images = pipe(
                prompt,
                num_inference_steps=args.inference_steps,
                height=height,
                width=width,
            ).images
            time_costs.append(time.perf_counter() - start)

            # Collect GPU memory statistics after each run.
            memory_metrics.append(list(get_paddle_memory_info()))

        mean_latency = np.mean(time_costs)
        # Average the per-run memory readings (columns follow
        # get_paddle_memory_info's return order).
        avg_allocated, _, avg_reserved, _ = np.mean(memory_metrics, axis=0)

        print(
            f"Use fp16: {'true' if args.use_fp16 else 'false'}, "
            f"Mean iter/sec: {1 / (mean_latency / args.inference_steps):.2f} it/s, "
            f"average end-to-end time :  {mean_latency * 1000:.2f} ms."
        )
        print(
            f"GPU average memory_allocated: {avg_allocated:.2f} GB, "
            f"average memory_reserved: {avg_reserved:.2f} GB"
        )
        print(f"GPU max_memory_allocated: {paddle.device.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")

        images[0].save(f"{folder}/text2img.png")


# Script entry point: parse the command-line options and run the benchmark.
if __name__ == "__main__":
    main(parse_arguments())
